Conditional Monge Gap: JIT-compatible loss + training estimator#679
Open
DhruvaRajwade wants to merge 6 commits into ott-jax:main from
Conversation
…IT compatibility

The original cmonge_gap_from_samples used a Python for-loop over jnp.unique(condition), which breaks JAX JIT compilation since jnp.unique returns a dynamically-sized array. Replace it with _segment_interface, which pads per-condition point clouds to a fixed max_measure_size and vmaps the per-segment Monge gap computation. This makes the function fully JIT-compatible.

The eval_fn computes per segment: displacement_cost - ent_reg_cost, matching the definition in monge_gap_from_samples. Padded entries have zero weight and do not affect the result.

New parameters num_segments and max_measure_size are required for JIT (consistent with the segment_sinkhorn API). Cost-function parameters (cost_fn, epsilon, relative_epsilon, scale_cost) are now explicit rather than passed through **kwargs.
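The two mechanisms in this commit can be illustrated in plain JAX (this is a sketch, not the actual ott-jax internals; `weighted_cost_loop` and `weighted_cost_segmented` are hypothetical names): a Python loop over `jnp.unique` cannot be traced under `jax.jit` because the number of conditions is data-dependent, while a static `num_segments` makes every shape fixed, and zero-weight padded rows drop out of the sum.

```python
import jax
import jax.numpy as jnp

def weighted_cost_loop(x, y, w, cond):
    # Data-dependent iteration count and boolean-mask indexing:
    # both fail at trace time under jax.jit.
    sq = jnp.sum((y - x) ** 2, axis=-1)
    return jnp.stack([jnp.sum(w[cond == c] * sq[cond == c])
                      for c in jnp.unique(cond)])

@jax.jit
def weighted_cost_segmented(x, y, w, cond):
    num_segments = 3  # static upper bound on the number of conditions
    sq = jnp.sum((y - x) ** 2, axis=-1)
    # One fixed-size reduction per segment; absent segments come out as 0.
    return jax.ops.segment_sum(w * sq, cond, num_segments=num_segments)

x = jnp.zeros((4, 2))
y = jnp.ones((4, 2))
w = jnp.full((4,), 0.25)
cond = jnp.array([0, 0, 1, 1])

# Padding a segment with a zero-weight row leaves its cost unchanged.
x_p = jnp.vstack([x, jnp.zeros((1, 2))])
y_p = jnp.vstack([y, jnp.full((1, 2), 9.0)])  # garbage values in the pad row...
w_p = jnp.append(w, 0.0)                      # ...but zero weight
cond_p = jnp.append(cond, 1)
```

Calling `weighted_cost_loop` inside a jitted function raises at trace time, while the segmented version compiles once for the fixed `num_segments` and agrees with the loop on the segments that are present.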
Add the estimator class that mirrors MongeGapEstimator but handles condition-dependent neural maps T(x, c) with per-condition Monge gap regularization via cmonge_gap_from_samples.

Changes:
- ConditionalMongeGapEstimator in conditional_monge_gap.py: training loop with a 3-arg regularizer(source, mapped, labels), a 4-iterator batch protocol, and a JIT-compiled step function
- ConditionalDataset + create_conditional_gaussian_mixture_samplers in datasets.py: synchronized 4-iterator data pipeline for testing
- Export conditional_perturbation_network from networks/__init__
- 16 tests: 8 unit tests for cmonge_gap_from_samples (non-negativity, JIT consistency, loop-baseline match, identity vs random, cost fns, return shape) + 2 integration tests for the estimator (convergence, no-regularizer mode)
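The step structure described above can be sketched as follows (hedged: all names here — `t_map`, `regularizer`, `step` — are hypothetical toy stand-ins, not the actual ConditionalMongeGapEstimator code; the regularizer is a placeholder for cmonge_gap_from_samples with the same 3-argument shape):

```python
import jax
import jax.numpy as jnp

def t_map(params, x, c):
    # Toy conditional map T(x, c): a linear map plus a per-condition shift.
    return x @ params["w"] + params["shifts"][c]

def regularizer(source, mapped, conditions):
    # Placeholder with the 3-arg signature regularizer(source, mapped, labels);
    # in the PR this role is played by cmonge_gap_from_samples.
    return jnp.mean((mapped - source) ** 2)

@jax.jit
def step(params, source, target, conditions, lr=0.1, lam=0.01):
    # One JIT-compiled training step: fitting loss + weighted regularizer.
    def loss_fn(p):
        mapped = t_map(p, source, conditions)
        fit = jnp.mean((mapped - target) ** 2)
        return fit + lam * regularizer(source, mapped, conditions)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

Iterating this step on synchronized (source, target, condition) batches — the 4-iterator protocol mentioned above supplies them per condition — drives the fitting loss down while the regularizer penalizes non-optimal displacements.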
… tests

Add 5 new tests to TestConditionalMongeGap:
- test_non_negativity_neural_map: PotentialMLP-based targets
- test_different_costs_give_different_values: PNormP, RegTICost(L1), RegTICost(STVS)
- test_uniform_conditions_equals_averaged_monge_gap: exact equivalence proof
- test_unequal_conditions_shifts_average: structural properties with padding
- test_per_condition_gaps_reflect_difficulty: monotonic gap ordering
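The idea behind the two equivalence tests can be shown on a toy per-condition cost (a sketch in plain JAX, not ott's Monge gap; all function names are hypothetical): with equal-size conditions, averaging per-condition values matches the pooled computation exactly, while unequal sizes shift the average.

```python
import jax.numpy as jnp

def per_condition_gap(x, y, cond, c):
    # Toy "gap": mean squared displacement within one condition.
    m = cond == c
    return jnp.mean(jnp.sum((y[m] - x[m]) ** 2, axis=-1))

def averaged_gap(x, y, cond):
    # Mean of per-condition gaps (eager-only loop over unique conditions).
    cs = jnp.unique(cond)
    return jnp.mean(jnp.stack([per_condition_gap(x, y, cond, c) for c in cs]))

def pooled_gap(x, y):
    # Pooled mean squared displacement over all points, ignoring conditions.
    return jnp.mean(jnp.sum((y - x) ** 2, axis=-1))

x = jnp.zeros((4, 2))
y = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
equal_sizes = jnp.array([0, 0, 1, 1])    # two conditions, two points each
unequal_sizes = jnp.array([0, 0, 0, 1])  # same points, 3-vs-1 split
```

With `equal_sizes` the two quantities coincide; with `unequal_sizes` the condition average weights the small group more heavily than the pooled mean does, which is the structural effect test_unequal_conditions_shifts_average documents.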
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #679 +/- ##
==========================================
+ Coverage 87.35% 87.39% +0.04%
==========================================
Files 82 84 +2
Lines 8476 8690 +214
Branches 581 600 +19
==========================================
+ Hits 7404 7595 +191
- Misses 922 937 +15
- Partials 150 158 +8
- Add logger.warning in cmonge_gap_from_samples when any condition is padded >10x its actual size (skipped under JIT via try/except)
- Add runtime timing to test_uniform_conditions_equals_averaged_monge_gap comparing segmented vs loop performance
- Rewrite PR_MESSAGE.md to ~half page with a concise overview and tutorial plot
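The "skipped under JIT via try/except" pattern from the first bullet can be sketched like this (hedged: `maybe_warn_padding` and `segmented_loss` are hypothetical illustrations, not the actual ott-jax code). Eagerly, the size is a concrete integer and the warning fires; under jit it is a tracer, so converting it to `int` raises a concretization error that the `except` swallows.

```python
import logging
import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)

def maybe_warn_padding(actual_size, max_measure_size, factor=10):
    # Host-side diagnostic: warn when a condition would be padded more
    # than `factor`x its actual size. Only meaningful outside jit.
    try:
        if int(actual_size) * factor < max_measure_size:
            logger.warning(
                "condition padded >%dx its actual size (%d -> %d)",
                factor, int(actual_size), max_measure_size,
            )
    except jax.errors.ConcretizationTypeError:
        pass  # under jit the size is a tracer; skip the check silently

@jax.jit
def segmented_loss(sizes):
    maybe_warn_padding(sizes[0], 128)  # no-op when traced
    return jnp.sum(sizes)  # placeholder for the actual segmented loss
```

This keeps the diagnostic out of the compiled computation entirely: the traced function compiles and runs unchanged, while eager callers still see the padding warning in their logs.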
Force-pushed: ae24357 to 3f8e6a1
@michalk8 whenever you have time, feel free to let us know any feedback on this PR!
Conditional Monge Gap: JIT-compatible loss + training estimator
The CMonge paper has been accepted to Nature Machine Intelligence. This PR picks up from #605, fixes the JIT issue, and adds a training estimator.
Changes
- `cmonge_gap_from_samples` -- replaced the `jnp.unique` loop (breaks `jax.jit`) with `_segment_interface` (pad + vmap), following @michalk8's suggestion. Two new required-for-JIT params: `num_segments`, `max_measure_size` (consistent with `segment_sinkhorn`). A `logger.warning` fires when any condition is padded >10x its actual size, since heavy padding can cause small numerical differences vs non-padded Sinkhorn.
- `ConditionalMongeGapEstimator` -- training wrapper mirroring `MongeGapEstimator` for conditional maps `T(x, c)`:

A potential precision tradeoff remains: when all conditions have the same `n_k`, the segment-based result matches `monge_gap_from_samples` to ~1e-7. With unequal `n_k`, smaller conditions are zero-padded to `max_measure_size`, which changes the Sinkhorn geometry slightly. This is inherent to the segment/vmap approach -- the trade-off is JIT compatibility.

Tests (26 passing)
All tests mirror `monge_gap_test.py` patterns wherever applicable: non-negativity (random + neural map targets), JIT consistency, cost function variants, estimator convergence. Three additional equivalence tests verify `cmonge_gap = mean(monge_gap_k)` for equal-size conditions, document the padding effect for unequal sizes, and check monotonic gap ordering by difficulty.

pytest tests/neural/methods/conditional_monge_gap_test.py -v  # 26 tests, ~80s

Usage example